{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/8_train_sqil.ipynb)\n",
    "# Train an Agent using Soft Q Imitation Learning\n",
    "\n",
    "Soft Q Imitation Learning ([SQIL](https://arxiv.org/abs/1905.11108)) is a simple algorithm that can be used to clone expert behavior.\n",
    "It's fundamentally a modification of the DQN algorithm. At each training step, whenever we sample a batch of data from the replay buffer,\n",
    "we also sample a batch of expert data. Expert demonstrations are assigned a reward of 1, while the agent's own transitions are assigned a reward of 0.\n",
    "This encourages the agent to take the expert's actions in demonstrated states and, when it drifts into unfamiliar states, to steer back towards demonstrated ones.\n",
    "\n",
    "In this tutorial we will use the `imitation` library to train an agent using SQIL."
   ]
  },
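  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the reward relabeling concrete, here is a minimal, framework-free sketch of how an SQIL training batch could be assembled. It is illustrative only and is not the `imitation` library's implementation. The `Transition` tuple and the `sqil_batch` helper are hypothetical, and the even split between expert and agent samples is an assumption of this sketch. Whatever reward the environment produced is discarded and replaced by 1 for expert data and 0 for agent data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sketch of SQIL batch construction (not the imitation library's code).\n",
    "import random\n",
    "from typing import List, NamedTuple\n",
    "\n",
    "\n",
    "class Transition(NamedTuple):\n",
    "    obs: int\n",
    "    act: int\n",
    "    next_obs: int\n",
    "    done: bool\n",
    "    reward: float = 0.0\n",
    "\n",
    "\n",
    "def sqil_batch(expert: List[Transition], agent: List[Transition], batch_size: int) -> List[Transition]:\n",
    "    # Assumption: half of the batch comes from each buffer.\n",
    "    n_expert = batch_size // 2\n",
    "    # Expert transitions are relabeled with reward 1 ...\n",
    "    expert_part = [t._replace(reward=1.0) for t in random.sample(expert, n_expert)]\n",
    "    # ... and the agent's own transitions with reward 0, ignoring the env reward.\n",
    "    agent_part = [t._replace(reward=0.0) for t in random.sample(agent, batch_size - n_expert)]\n",
    "    return expert_part + agent_part\n",
    "\n",
    "\n",
    "# Toy buffers: the agent's environment reward of 5.0 is overwritten with 0.\n",
    "random.seed(0)\n",
    "expert_buffer = [Transition(obs=i, act=1, next_obs=i + 1, done=False) for i in range(10)]\n",
    "agent_buffer = [Transition(obs=i, act=0, next_obs=i + 1, done=False, reward=5.0) for i in range(10)]\n",
    "batch = sqil_batch(expert_buffer, agent_buffer, batch_size=8)\n",
    "print(sum(t.reward for t in batch))  # 4.0 (four expert transitions with reward 1)"
   ]
  },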
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we need some expert trajectories in our environment (`CartPole-v1`).\n",
    "Note that you can use other environments, but this algorithm requires a discrete action space."
   ]
  },
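  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because SQIL builds on DQN, it only works with discrete action spaces. A quick way to check an environment before training is sketched below. `has_discrete_actions` is a hypothetical helper, and the continuous-action `Pendulum-v1` environment appears only as a counterexample."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "\n",
    "\n",
    "def has_discrete_actions(env_id: str) -> bool:\n",
    "    # Hypothetical helper: True if the environment's action space is Discrete.\n",
    "    env = gym.make(env_id)\n",
    "    try:\n",
    "        return isinstance(env.action_space, gym.spaces.Discrete)\n",
    "    finally:\n",
    "        env.close()\n",
    "\n",
    "\n",
    "print(has_discrete_actions(\"CartPole-v1\"))  # True: Discrete(2) actions\n",
    "print(has_discrete_actions(\"Pendulum-v1\"))  # False: continuous Box actions"
   ]
  },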
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "from stable_baselines3.common.vec_env import DummyVecEnv\n",
    "\n",
    "from imitation.data import huggingface_utils\n",
    "\n",
    "# Download some expert trajectories from the HuggingFace Datasets Hub.\n",
    "dataset = datasets.load_dataset(\"HumanCompatibleAI/ppo-CartPole-v1\")\n",
    "\n",
    "# Convert the dataset to a format usable by the imitation library.\n",
    "expert_trajectories = huggingface_utils.TrajectoryDatasetSequence(dataset[\"train\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's quickly check whether the expert is any good.\n",
    "A good expert should reach an average return close to 500, the maximum achievable return in `CartPole-v1`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from imitation.data import rollout\n",
    "\n",
    "trajectory_stats = rollout.rollout_stats(expert_trajectories)\n",
    "\n",
    "print(\n",
    "    f\"We have {trajectory_stats['n_traj']} trajectories. \"\n",
    "    f\"The average length of each trajectory is {trajectory_stats['len_mean']}. \"\n",
    "    f\"The average return of each trajectory is {trajectory_stats['return_mean']}.\"\n",
    ")"
   ]
  },
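  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why 500? In `CartPole-v1` the agent receives a reward of +1 for every step, and episodes are truncated after 500 steps, so an episode's return equals its length and cannot exceed 500. The toy reward lists below are made up for illustration; they are not taken from the expert dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Made-up per-step rewards for two CartPole-v1 episodes (+1 per step).\n",
    "toy_episode_rewards = [[1.0] * 500, [1.0] * 320]\n",
    "\n",
    "toy_returns = [sum(rewards) for rewards in toy_episode_rewards]\n",
    "toy_lengths = [len(rewards) for rewards in toy_episode_rewards]\n",
    "\n",
    "print(toy_returns)  # [500.0, 320.0]\n",
    "print(toy_returns == toy_lengths)  # True: the return equals the length\n",
    "print(sum(toy_returns) / len(toy_returns))  # 410.0, the mean return"
   ]
  },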
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have collected our expert trajectories, it's time to set up our imitation algorithm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "\n",
    "from imitation.algorithms import sqil\n",
    "\n",
    "venv = DummyVecEnv([lambda: gym.make(\"CartPole-v1\")])\n",
    "sqil_trainer = sqil.SQIL(\n",
    "    venv=venv,\n",
    "    demonstrations=expert_trajectories,\n",
    "    policy=\"MlpPolicy\",\n",
    ")"
   ]
  },
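  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since SQIL is a modification of DQN, the relabeled rewards enter training through the usual one-step Q-learning target. The sketch below computes that target for a single transition; `dqn_target` and `soft_value` are hypothetical helpers written for illustration, not part of `imitation`. A DQN-style target uses a hard max over next-state Q-values, while the SQIL paper uses a soft (log-sum-exp) state value, shown by `soft_value`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from typing import Sequence\n",
    "\n",
    "\n",
    "def dqn_target(reward: float, done: bool, next_q: Sequence[float], gamma: float) -> float:\n",
    "    # One-step DQN target: r + gamma * max_a' Q(s', a'), with no bootstrap at episode end.\n",
    "    return reward + (0.0 if done else gamma * max(next_q))\n",
    "\n",
    "\n",
    "def soft_value(q: Sequence[float], alpha: float = 1.0) -> float:\n",
    "    # Soft state value alpha * log(sum_a exp(Q(s, a) / alpha)), computed stably.\n",
    "    m = max(q)\n",
    "    return m + alpha * math.log(sum(math.exp((x - m) / alpha) for x in q))\n",
    "\n",
    "\n",
    "next_q = [2.0, 3.0]\n",
    "print(dqn_target(1.0, False, next_q, gamma=0.5))  # 2.5: expert sample, reward 1\n",
    "print(dqn_target(0.0, False, next_q, gamma=0.5))  # 1.5: agent sample, reward 0\n",
    "print(dqn_target(1.0, True, next_q, gamma=0.5))  # 1.0: terminal, no bootstrap\n",
    "print(soft_value(next_q) > max(next_q))  # True: the soft value exceeds the hard max"
   ]
  },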
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The untrained policy only achieves poor rewards:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines3.common.evaluation import evaluate_policy\n",
    "\n",
    "reward_before_training, _ = evaluate_policy(sqil_trainer.policy, venv, n_eval_episodes=10)\n",
    "print(f\"Reward before training: {reward_before_training}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
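    "With enough training, the agent can match the expert's return of 500. To keep this tutorial fast, the cell below trains for only 1,000 timesteps, so expect a lower reward; set `total_timesteps` to 1,000,000 to obtain good results:"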
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: set total_timesteps to 1_000_000 to obtain good results.\n",
    "sqil_trainer.train(total_timesteps=1_000)\n",
    "reward_after_training, _ = evaluate_policy(sqil_trainer.policy, venv, n_eval_episodes=10)\n",
    "print(f\"Reward after training: {reward_after_training}\")"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "bd378ce8f53beae712f05342da42c6a7612fc68b19bea03b52c7b1cdc8851b5f"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
